from pyspark import SparkConf
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, FloatType
from pyspark.sql import functions as F
import csv
from itertools import islice

def load_test_data(data_path):
    """Read the test-set CSV and return the user ids from its first column, skipping the header row."""
    user_id_list = []
    with open(data_path, "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        for line in islice(reader, 1, None):
            user_id_list.append(int(line[0]))
    return user_id_list

if __name__ == '__main__':
    conf = SparkConf().setAppName('collaborativeFiltering').setMaster('local')
    spark = SparkSession.builder.config(conf=conf).getOrCreate()

    # Load the raw ratings CSV and cast its string columns to numeric types
    ratingResourcesPath = "./dataset/new_dataset.csv"
    ratingSamples = spark.read.format('csv').option('header', 'true').load(ratingResourcesPath) \
        .withColumn("userIdInt", F.col("user_id").cast(IntegerType())) \
        .withColumn("itemIdInt", F.col("item_id").cast(IntegerType())) \
        .withColumn("ratingFloat", F.col("rating").cast(FloatType()))
    training, test = ratingSamples.randomSplit([0.8, 0.2])
    # Build the recommendation model using ALS on the training data
    # Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics
    als = ALS(regParam=0.01, maxIter=5, userCol='userIdInt', itemCol='itemIdInt', ratingCol='ratingFloat',
              coldStartStrategy='drop')
    model = als.fit(training)
    # Inspect a sample of the learned user/item latent factor vectors
    model.itemFactors.show(10, truncate=False)
    model.userFactors.show(10, truncate=False)
    # Evaluate the model by computing the RMSE on the test data
    predictions = model.transform(test)
    evaluator = RegressionEvaluator(predictionCol="prediction", labelCol='ratingFloat', metricName='rmse')
    rmse = evaluator.evaluate(predictions)
    print("Root-mean-square error = {}".format(rmse))

    # Generate the top-10 item recommendations for every user
    recommendations = model.recommendForAllUsers(10)

    # Write each test user's recommended item ids to the submission file
    test_user_id_list = load_test_data("./dataset/book_test_dataset.csv")
    with open("submission_lfm_2.csv", "w", encoding="utf-8", newline="") as file:
        csv_writer = csv.writer(file)
        csv_writer.writerow(["user_id", "item_id"])
        count = 0
        for uid in test_user_id_list:
            rows = recommendations.where(recommendations.userIdInt == uid).collect()
            # Each matching row holds (userIdInt, recommendations), where the second
            # field is a list of (itemIdInt, rating) structs; write the item id
            # rather than the whole Row object. A user may be absent entirely when
            # coldStartStrategy='drop' removed all of their ratings.
            if rows:
                for rec in rows[0]["recommendations"]:
                    csv_writer.writerow([str(uid), str(rec["itemIdInt"])])
            count += 1
            if count % 20 == 0:
                print("Recommendation progress: {:.2f}%".format(count / len(test_user_id_list) * 100))
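
    # A hypothetical batch alternative to the per-user collect() loop above: each
    # collect() launches a full Spark job, so filtering the recommendations to the
    # test users with a join and writing once is far faster on large user lists.
    # Minimal commented-out sketch; "submission_lfm_2_batch.csv" is an illustrative
    # output name, not a file the pipeline expects.
    # test_users_df = spark.createDataFrame([(u,) for u in test_user_id_list], ["userIdInt"])
    # batch = recommendations.join(test_users_df, on="userIdInt") \
    #     .select("userIdInt", F.explode("recommendations").alias("rec")) \
    #     .select(F.col("userIdInt").alias("user_id"), F.col("rec.itemIdInt").alias("item_id"))
    # batch.toPandas().to_csv("submission_lfm_2_batch.csv", index=False)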

    spark.stop()